import torch
from torch.utils.data import Dataset, DataLoader
from transformers import RobertaForSequenceClassification, AdamW

import pandas as pd
from tqdm.auto import tqdm
from sklearn.model_selection import KFold
from sklearn.metrics import mean_squared_error
from scipy.stats import pearsonr, spearmanr
import random
import numpy as np
from sklearn.utils import check_random_state

import gc
import copy
import os

# Seed the global random-number generators used by this script so that
# DataLoader shuffling and any randomly initialised weights are reproducible.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
# Separately seeded numpy RandomState; not referenced elsewhere in this script.
rng = check_random_state(42)

import pickle

# Load the pre-tokenised STS encodings: a mapping of tensors that must provide
# at least 'input_ids', 'attention_mask' and 'labels' (the keys consumed below).
# NOTE(review): pickle.load can execute arbitrary code -- only load trusted,
# locally produced files.
with open("/home/projects/home/SKMT/STS/encodings/encodings_STS_SKMT", mode='rb') as f:
    encodings = pickle.load(f)

class STSDataset(torch.utils.data.Dataset):
    """Map-style dataset serving one pre-tokenised STS example per index.

    Args:
        encodings: mapping from field name to a tensor whose first dimension
            enumerates examples. It must provide ``input_ids``,
            ``attention_mask`` and ``labels``; any other fields are copied
            per item but not returned.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # The number of examples is the first dimension of input_ids.
        input_ids = self.encodings['input_ids']
        return input_ids.shape[0]

    def __getitem__(self, idx):
        """Return detached copies of the model inputs and the label for row ``idx``."""
        copied = {}
        for field, values in self.encodings.items():
            copied[field] = values[idx].clone().detach()
        return {
            field: copied[field]
            for field in ('input_ids', 'attention_mask', 'labels')
        }

# ---------------------------------------------------------------------------
# Fine-tuning: K-fold cross-validation of the single-output RoBERTa model on STS.
# ---------------------------------------------------------------------------
num_epochs = 15
patience = 3  # stop after this many consecutive epochs without a validation-MSE improvement
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Checkpoint every fold starts from; reloaded per fold so folds do not share weights.
local_model_path = "/home/projects/home/SKMT/models/SKMT_epoch_9_encodings_2"

kf = KFold(n_splits=10, shuffle=True, random_state=42)
all_fold_results = []  # best validation metrics of each fold

columns = ['fold', 'epoch', 'train_loss', 'eval_loss', 'mse', 'pearson', 'spearman']
epoch_rows = []  # one dict per (fold, epoch); rebuilt into `df` before each save
df = pd.DataFrame(columns=columns)


def _subset_encodings(source, indices):
    """Return cloned ``input_ids``/``attention_mask``/``labels`` rows selected by *indices*."""
    return {
        key: source[key][indices].clone()
        for key in ('input_ids', 'attention_mask', 'labels')
    }


def _train_one_epoch(model, dataloader, optimizer, device, desc):
    """Run one optimisation pass over *dataloader*; return the mean per-batch loss."""
    model.train()
    total_loss = 0.0
    progress = tqdm(dataloader, total=len(dataloader), leave=True, desc=desc)
    for batch in progress:
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # Float labels with an explicit trailing dim for the num_labels=1 head
        # (assumes each example's label is a scalar -- TODO confirm).
        labels = batch['labels'].to(device).float().unsqueeze(1)

        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        progress.set_postfix(loss=batch_loss)
        total_loss += batch_loss
    return total_loss / len(dataloader)


def _evaluate(model, dataloader, device, desc):
    """Score *dataloader* without gradient tracking.

    Returns:
        (mean per-batch loss, flat list of true labels, flat list of predictions)
    """
    model.eval()
    total_loss = 0.0
    y_true = []
    y_pred = []
    progress = tqdm(dataloader, total=len(dataloader), leave=True, desc=desc)
    with torch.no_grad():
        for batch in progress:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device).float().unsqueeze(1)
            outputs = model(input_ids, attention_mask=attention_mask, labels=labels)

            batch_loss = outputs.loss.item()
            progress.set_postfix(loss=batch_loss)
            total_loss += batch_loss

            y_pred.extend(outputs.logits.cpu().numpy().flatten())
            y_true.extend(labels.cpu().numpy().flatten())
    return total_loss / len(dataloader), y_true, y_pred


for fold, (train_indices, test_indices) in enumerate(kf.split(encodings['input_ids'])):
    model = RobertaForSequenceClassification.from_pretrained(local_model_path, num_labels=1)
    # Move the model before building the optimizer, as PyTorch recommends.
    model.to(device)
    # torch.optim.AdamW replaces the deprecated transformers.AdamW; eps=1e-6
    # matches the transformers implementation's default.
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5, eps=1e-6, weight_decay=0.01)

    train_dataset = STSDataset(_subset_encodings(encodings, train_indices))
    val_dataset = STSDataset(_subset_encodings(encodings, test_indices))

    train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=0)
    # Validation order does not affect the metrics; keeping it fixed avoids
    # consuming torch RNG state and keeps evaluation deterministic.
    val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, num_workers=0)

    best_mse = float('inf')
    best_pearson = -float('inf')
    best_spearman = -float('inf')
    best_epoch = 0
    patience_counter = 0

    for epoch in range(num_epochs):
        desc = f"Fold {fold+1} - Epoch {epoch+1}/{num_epochs}"
        avg_train_loss = _train_one_epoch(model, train_dataloader, optimizer, device, f"{desc} - Train")
        avg_val_loss, y_true, y_pred = _evaluate(model, val_dataloader, device, f"{desc} - Validation")

        mse = mean_squared_error(y_true, y_pred)
        pearson_corr, _ = pearsonr(y_true, y_pred)
        spearman_corr, _ = spearmanr(y_true, y_pred)

        # Report the epoch's averaged losses and validation metrics.
        print(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {avg_train_loss:.4f}, Val Loss: {avg_val_loss:.4f}, MSE: {mse:.4f}, Pearson: {pearson_corr:.4f}, Spearman: {spearman_corr:.4f}")

        epoch_rows.append({
            'fold': fold + 1,
            'epoch': epoch + 1,
            'train_loss': avg_train_loss,
            'eval_loss': avg_val_loss,
            'mse': mse,
            'pearson': pearson_corr,
            'spearman': spearman_corr,
        })
        # Rebuilding from a list of rows avoids pandas' deprecated
        # concat-with-empty-frame behaviour and yields numeric dtypes.
        df = pd.DataFrame(epoch_rows, columns=columns)
        # Cumulative checkpoint of all rows so far (same file name is reused per epoch across folds).
        df.to_excel(f"/home/projects/home/SKMT/STS/results/m9/SKMT_folds_loss_epoch_{epoch+1}.xlsx")

        if mse < best_mse:
            # best_model = copy.deepcopy(model)
            best_mse = mse
            best_pearson = pearson_corr
            best_spearman = spearman_corr
            best_epoch = epoch + 1
            patience_counter = 0
        else:
            patience_counter += 1
            # `>=` so training stops after exactly `patience` non-improving
            # epochs (the previous `>` allowed patience + 1).
            if patience_counter >= patience:
                print(f"Early stopping on epoch {epoch + 1}")
                break

        gc.collect()

    all_fold_results.append({
        'fold': fold + 1,
        'best_epoch': best_epoch,
        'mse': best_mse,
        'pearson': best_pearson,
        'spearman': best_spearman,
    })
    print(f"Fold {fold+1} best (epoch {best_epoch}) - MSE: {best_mse:.4f}, Pearson: {best_pearson:.4f}, Spearman: {best_spearman:.4f}")

    # Save the model for the current fold (currently disabled).
    # model_save_path = f'/home/projects/home/SKMT/STS/models/SKMT_fold_{fold+1}'
    # if not os.path.exists(model_save_path):
    #     os.makedirs(model_save_path)
    # best_model.save_pretrained(model_save_path)

    # Release this fold's weights and optimizer state before the next fold
    # loads a fresh copy, and return cached GPU blocks to the allocator.
    del model, optimizer, train_dataloader, val_dataloader
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

fold_summary = pd.DataFrame(all_fold_results)
print(f"Mean over {len(fold_summary)} folds - MSE: {fold_summary['mse'].mean():.4f}, Pearson: {fold_summary['pearson'].mean():.4f}, Spearman: {fold_summary['spearman'].mean():.4f}")

df.to_excel("/home/projects/home/SKMT/STS/results/m9/SKMT_folds_loss_final.xlsx")
